{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Record the directory the kernel was in the first time this cell runs and\n",
    "# always change directory relative to it. A bare os.chdir('../') climbs one\n",
    "# more level on every re-run of the cell; anchoring on the recorded start\n",
    "# directory makes the cell idempotent within a kernel session.\n",
    "if '_NOTEBOOK_START_DIR' not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, '../'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from redunet import ReduNetVector\n",
    "import utils_example as ue\n",
    "import plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: values passed to ue.generate_3d further down\n",
    "dataset = 1  # valid choices are 1 to 8\n",
    "train_noise, test_noise = 0.1, 0.1\n",
    "train_samples, test_samples = 100, 100\n",
    "\n",
    "# Model: values passed to the ReduNetVector constructor further down\n",
    "num_layers = 200  # number of redunet layers\n",
    "eta = 0.5\n",
    "eps = 0.1\n",
    "lmbda = 200"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training split\n",
    "X_train, y_train, num_classes = ue.generate_3d(dataset, train_noise, train_samples)\n",
    "# Test split: same dataset id; note that num_classes is reassigned by this call\n",
    "X_test, y_test, num_classes = ue.generate_3d(dataset, test_noise, test_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network from the hyperparameters defined above.\n",
    "net = ReduNetVector(\n",
    "    num_classes,\n",
    "    num_layers,\n",
    "    X_train.shape[1],  # presumably the input feature dimension - confirm against ReduNetVector\n",
    "    eta=eta,\n",
    "    eps=eps,\n",
    "    lmbda=lmbda,\n",
    ")\n",
    "\n",
    "# net.init takes the labelled training data and returns what is plotted below as Z_train.\n",
    "Z_train = net.init(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot whatever net.get_loss() returns after initialisation.\n",
    "ue.plot_loss_mcr(net.get_loss())\n",
    "\n",
    "# Training data as generated (X_train) and as returned by net.init (Z_train),\n",
    "# both drawn with the training labels.\n",
    "ue.plot_3d(X_train, y_train, 'X_train')\n",
    "ue.plot_3d(Z_train, y_train, 'Z_train')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Run the test data through the network and detach the result before plotting.\n",
    "Z_test = net(X_test).detach()\n",
    "\n",
    "# Raw test inputs.\n",
    "ue.plot_3d(X_test, y_test, 'X_test')\n",
    "# Network output for the test inputs: the 'Z_test' figure must be given Z_test,\n",
    "# not X_test, otherwise it is just a second copy of the input plot.\n",
    "ue.plot_3d(Z_test, y_test, 'Z_test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "redunet",
   "language": "python",
   "name": "redunet"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
